import torch
from bgflow.utils import as_numpy
import matplotlib.pyplot as plt


def superpose_points(points1, points2):
    """Align ``points1`` onto ``points2`` with the best-fitting orthogonal map.

    The transform is the orthogonal Procrustes solution ``R = U @ V^T``
    obtained from the SVD of the cross-covariance ``points1^T @ points2``.

    Note:
        * No translation is applied and neither input is centred here; the
          caller is presumably expected to pass mean-free coordinates
          (TODO confirm against callers).
        * No determinant correction is performed, so ``R`` may be an
          improper rotation (i.e. include a reflection).

    Args:
        points1: 2-D tensor of points to be transformed, one point per row.
        points2: 2-D tensor of reference points with the same shape.

    Returns:
        ``points1 @ R``, a tensor with the same shape as ``points1``.
    """
    # Cross-covariance between the two point sets.
    covariance = points1.t() @ points2

    # torch.svd is deprecated; torch.linalg.svd returns V already transposed
    # (Vh), so U @ Vh equals the former U @ V.t() for real-valued input.
    U, _, Vh = torch.linalg.svd(covariance, full_matrices=False)
    transform = U @ Vh

    return points1 @ transform


def superpose_points_batch(points, reference):
    """Batched alignment of ``points`` onto ``reference``.

    For every batch element the orthogonal Procrustes solution
    ``R = U @ V^T`` is computed from the SVD of the cross-covariance
    ``points^T @ reference`` and applied to ``points``.

    Note:
        * No translation is applied and the inputs are not centred here;
          presumably the caller passes mean-free coordinates (TODO confirm).
        * No determinant correction is performed, so ``R`` may include a
          reflection.

    Args:
        points: Tensor whose last two dimensions are (n_points, n_dims);
            leading dimensions are treated as batch dimensions.
        reference: Reference points. Combined with ``points`` through
            ``matmul``, so its batch dimensions must be broadcastable
            against those of ``points``.

    Returns:
        ``points @ R`` for every batch element.
    """
    # Per-sample cross-covariance; transpose only the two trailing axes.
    covariance = points.transpose(-2, -1) @ reference

    # torch.svd is deprecated; torch.linalg.svd returns V already transposed
    # (Vh), so U @ Vh equals the former U @ V.transpose(-2, -1) for real input.
    U, _, Vh = torch.linalg.svd(covariance, full_matrices=False)
    transform = U @ Vh

    return points @ transform


def remove_mean(x):
    """Return ``x`` with its mean along dimension 1 subtracted.

    Args:
        x: Tensor with at least two dimensions. Presumably laid out as
            (batch, n_nodes, n_dims) - TODO confirm against callers.

    Returns:
        A new tensor of the same shape whose mean along dimension 1 is zero.
    """
    # keepdim=True keeps the reduced axis so the subtraction broadcasts.
    return x - x.mean(dim=1, keepdim=True)


def remove_mean_with_mask(x, node_mask):
    """Subtract the masked mean along dimension 1 from the unmasked entries.

    Args:
        x: Tensor of values; entries where ``node_mask`` is 0 must already
            be (numerically) zero.
        node_mask: Mask that broadcasts against ``x``; presumably shaped
            (batch, n_nodes, 1) with 1 for real nodes and 0 for padding -
            TODO confirm against callers.

    Returns:
        ``x`` with the per-sample mean removed at positions where the mask
        is set; masked-out positions are left unchanged.

    Raises:
        AssertionError: If the masked-out part of ``x`` is not zero.
    """
    # Sanity check: everything outside the mask has to be zero already.
    masked_out_total = (x * (1 - node_mask)).abs().sum().item()
    assert masked_out_total < 1e-8

    # Number of active nodes per sample (reduced axis kept for broadcasting).
    n_active = node_mask.sum(1, keepdims=True)

    # Masked-out entries are zero, so a plain sum divided by the active
    # count yields the mean over the active nodes only.
    masked_mean = x.sum(dim=1, keepdim=True) / n_active
    return x - masked_mean * node_mask


def plot_flowpath_trajectory(traj, n_dimensions=2):
    """Scatter-plot a flow trajectory in 2D on a new 9x9 figure.

    The first entry of ``traj`` is drawn as the "latent" sample, the last
    entry as the "target" sample, and every entry as small black "path"
    markers. The figure is created but neither shown nor returned.

    Args:
        traj: Trajectory indexed as ``traj[step, point, coordinate]``;
            presumably a torch tensor accepted by ``as_numpy`` - TODO confirm.
        n_dimensions: Size of the coordinate axis used when reshaping the
            first/last entries. Only coordinates 0 and 1 of the path are
            drawn, so values other than 2 are presumably not meaningful
            here - TODO confirm.
    """
    plt.figure(figsize=(9, 9))
    latent_sample = as_numpy(traj[0].reshape(-1, n_dimensions))
    target_sample = as_numpy(traj[-1].reshape(-1, n_dimensions))
    path = as_numpy(traj)

    # Drawing order is kept (latent, path, target) because it determines
    # the automatic colours and the z-order of the markers.
    plt.scatter(*latent_sample.T, alpha=0.95, label="latent", s=100)
    plt.scatter(
        path[:, :, 0].flatten(),
        path[:, :, 1].flatten(),
        color="black",
        s=10,
        label="path",
    )
    plt.scatter(*target_sample.T, alpha=0.95, label="target", s=100)

    plt.title("Flow path", fontsize=45)
    plt.xticks(fontsize=45)
    plt.yticks(fontsize=45)
    # Single legend call: the former extra plt.legend() was immediately
    # replaced by this one and had no effect on the final figure.
    plt.legend(fontsize=25)


def plot_flowpath_trajectory_3d(traj, n_dimensions=3):
    """Scatter-plot a flow trajectory on a new 9x9 figure with 3D axes.

    The first entry of ``traj`` is drawn as the "latent" sample, the last
    entry as the "target" sample, and every entry as small black "path"
    markers. All three axis ranges are fixed to [-1.5, 1.5]. The figure is
    created but neither shown nor returned, and no legend is drawn.

    Args:
        traj: Trajectory indexed as ``traj[step, point, coordinate]``;
            presumably a torch tensor accepted by ``as_numpy`` - TODO confirm.
        n_dimensions: Size of the coordinate axis used when reshaping the
            first/last entries.
    """
    figure = plt.figure(figsize=(9, 9))
    axes = figure.add_subplot(projection="3d")

    start_points = as_numpy(traj[0].reshape(-1, n_dimensions))
    end_points = as_numpy(traj[-1].reshape(-1, n_dimensions))
    path = as_numpy(traj)
    path_x = path[:, :, 0].flatten()
    path_y = path[:, :, 1].flatten()
    path_z = path[:, :, 2].flatten()

    # Drawing order is kept (latent, path, target) because it determines
    # the automatic colours of the two large marker sets.
    axes.scatter(*start_points.T, alpha=0.95, label="latent", s=100)
    axes.scatter(path_x, path_y, path_z, color="black", s=10, label="path")
    axes.scatter(*end_points.T, alpha=0.95, label="target", s=100)

    axis_range = (-1.5, 1.5)
    axes.set_xlim(axis_range)
    axes.set_ylim(axis_range)
    axes.set_zlim(axis_range)


def create_adjacency_list(distance_matrix, atom_types):
    """Build a bond list from pairwise distances and atom types.

    Each unordered pair ``(i, j)`` with ``i < j`` is bonded when its
    distance is strictly below a cutoff chosen from the atom types of the
    pair. The first matching rule wins:

    * type 1 involved -> 0.14
    * type 4 involved -> 0.22
    * type 0 involved -> 0.18
    * otherwise       -> 0.0 (never bonded)

    The meaning of the type codes and the distance unit are not defined in
    this function (type 1 is presumably hydrogen and distances presumably
    in nm - TODO confirm).

    Args:
        distance_matrix: Square, indexable as ``distance_matrix[i][j]``.
        atom_types: Indexable sequence of atom type codes, one per atom.

    Returns:
        List of ``[i, j]`` index pairs with ``i < j``.
    """
    # (atom type, cutoff) rules in priority order.
    cutoff_rules = ((1, 0.14), (4, 0.22), (0, 0.18))

    n_atoms = len(distance_matrix)
    bonds = []
    for first in range(n_atoms):
        # Start at first + 1 so each unordered pair is visited once.
        for second in range(first + 1, n_atoms):
            separation = distance_matrix[first][second]
            pair_types = (atom_types[first], atom_types[second])

            # Pairs matching no rule keep a cutoff of 0.0 and never bond.
            cutoff = 0.0
            for type_code, type_cutoff in cutoff_rules:
                if type_code in pair_types:
                    cutoff = type_cutoff
                    break

            if separation < cutoff:
                bonds.append([first, second])

    return bonds


# Check whether the chirality is the same.
# If not --> mirror.
# If it is still not the same --> discard.
def find_chirality_centers(
    adj_list: torch.Tensor, atom_types: torch.Tensor, num_h_atoms: int = 2
) -> torch.Tensor:
    """
    Return the chirality centers for a peptide, e.g. carbon alpha atoms and their bonds.

    A candidate center is any atom that occurs in exactly four bonds. It is
    reported together with the first three of its bonded atoms unless too
    many of its neighbours are hydrogens (atom type 1).

    Args:
        adj_list: Tensor of bonds with one ``(i, j)`` atom index pair per row.
        atom_types: Tensor of atom types, indexed by atom index.
        num_h_atoms: If num_h_atoms or more hydrogen atoms are connected to the
            center, it is not reported.
            Default is 2, because in this case the mirroring is a simple permutation.

    Returns:
        Long tensor of shape ``(n_centers, 4)`` on the device of ``adj_list``;
        column 0 is the center, columns 1-3 are bonded atoms. The shape is
        ``(0, 4)`` when no center is found.
    """
    # Select the atom indices themselves. Using the positions of the matching
    # counts (as torch.where would give) is only correct when every atom
    # index 0..N-1 appears in adj_list.
    atom_ids, bond_counts = torch.unique(adj_list, return_counts=True)
    candidate_centers = atom_ids[bond_counts == 4]

    chirality_centers = []
    for center in candidate_centers:
        bond_idx, bond_pos = torch.where(adj_list == center)
        # The bonded partner sits in the other column of the same row.
        bonded_idxs = adj_list[bond_idx, (bond_pos + 1) % 2].long()
        # Count hydrogens directly so the threshold matches the documented
        # meaning of num_h_atoms for every value (identical for the default).
        n_hydrogens = int((atom_types[bonded_idxs] == 1).sum())
        if n_hydrogens < num_h_atoms:
            chirality_centers.append(
                torch.cat([center.long().reshape(1), bonded_idxs[:3]])
            )

    if not chirality_centers:
        # Keep the result two-dimensional so column indexing still works.
        return torch.empty((0, 4), dtype=torch.long, device=adj_list.device)
    return torch.stack(chirality_centers)


def compute_chirality_sign(
    coords: torch.Tensor, chirality_centers: torch.Tensor
) -> torch.Tensor:
    """
    Compute an orientation sign for every chirality center of every configuration.

    For each center the three vectors pointing from the center to its listed
    neighbours are formed, and the sign of their scalar triple product
    ``v0 . (v1 x v2)`` is returned. If the signs of two configurations differ
    for the same center, the chirality changed.

    Args:
        coords: Tensor of atom coordinates with three dimensions; the last
            one must have size 3 for the cross product (presumably
            (batch, n_atoms, 3)).
        chirality_centers: Index tensor with one center per row; column 0 is
            the center atom and columns 1-3 are the neighbours that are used.

    Returns:
        Tensor of signs (-1, 0 or 1), one row per configuration and one
        column per chirality center.
    """
    assert coords.dim() == 3

    center_idx = chirality_centers[:, [0]]  # extra axis kept for broadcasting
    neighbour_idx = chirality_centers[:, 1:]

    # Vectors from each center to its neighbours.
    offsets = coords[:, neighbour_idx, :] - coords[:, center_idx, :]
    first = offsets[:, :, 0]
    second = offsets[:, :, 1]
    third = offsets[:, :, 2]

    # Scalar triple product: first . (second x third).
    normal = torch.cross(second, third, dim=-1)
    triple_product = torch.einsum("ijk, ijk->ij", first, normal)
    return torch.sign(triple_product)


def check_symmetry_change(
    coords: torch.Tensor, chirality_centers: torch.Tensor, reference_signs: torch.Tensor
) -> torch.Tensor:
    """
    Flag the configurations of a batch whose chirality differs from a reference.

    The signs from ``compute_chirality_sign`` are compared with
    ``reference_signs``; a configuration is flagged as soon as the sign of at
    least one chirality center differs.

    Args:
        coords: Tensor of atom coordinates
        chirality_centers: List of chirality_centers
        reference_signs: Reference signs for the chirality_centers; converted
            to the dtype and device of ``coords`` before the comparison.
    Returns:
        Boolean mask that is True where the chirality changed.
    """
    current_signs = compute_chirality_sign(coords, chirality_centers)
    expected_signs = reference_signs.to(coords)
    # Reduce over the last axis: any differing center marks the configuration.
    mismatch = current_signs != expected_signs
    return mismatch.any(dim=-1)
